#!/usr/bin/env python3
#
# Runtime test for the INET Framework.
#
# Accepts one or more CSV files with 4 columns: working directory,
# options to opp_run, simulation time limit, expected runtime.
# The program runs the INET simulations in the CSV files, and
# reports runtime differences. To facilitate test suite maintenance,
# the program also creates a new file (or files) with the updated runtimes.
#
# Implementation is based on Python's unit testing library, so it can be
# integrated into larger test suites with minimal effort
#
# Author: Andras Varga
#

import argparse
import numpy
import copy
import csv
import glob
import multiprocessing
import os
import re
import subprocess
import sys
import threading
import time
import unittest
from distutils import spawn
from StringIO import StringIO


inetRoot = os.path.abspath("../..")
sep = ";" if sys.platform == 'win32' else ':'
nedPath = inetRoot + "/src" + sep + inetRoot + "/examples" + sep + inetRoot + "/showcases" + sep + inetRoot + "/tutorials" + sep + inetRoot + "/tests/networks"
inetLib = inetRoot + "/src/INET"
cpuTimeLimit = "30000s"
logFile = "test.out"
extraOppRunArgs = ""
exitCode = 0

if spawn.find_executable("opp_run_release"):
	executable = spawn.find_executable("opp_run_release")
elif spawn.find_executable("opp_run"):
	executable = spawn.find_executable("opp_run")
elif spawn.find_executable("opp_run_dbg"):
	executable = spawn.find_executable("opp_run_dbg")

class runtimeTestCaseGenerator():
    fileToSimulationsMap = {}
    def generateFromCSV(self, csvFileList, filterRegexList, excludeFilterRegexList, repeat):
        testcases = []
        for csvFile in csvFileList:
            simulations = self.parseSimulationsTable(csvFile)
            self.fileToSimulationsMap[csvFile] = simulations
            testcases.extend(self.generateFromDictList(simulations, filterRegexList, excludeFilterRegexList, repeat))
        return testcases

    def generateFromDictList(self, simulations, filterRegexList, excludeFilterRegexList, repeat):
        class StoreMeasuredRuntimeCallback:
            def __init__(self, simulation):
                self.simulation = simulation
            def __call__(self, measuredRuntime):
                self.simulation['measuredRuntime'] = measuredRuntime

        class StoreExitcodeCallback:
            def __init__(self, simulation):
                self.simulation = simulation
            def __call__(self, exitcode):
                self.simulation['exitcode'] = exitcode

        testcases = []
        for simulation in simulations:
            title = simulation['wd'] + " " + simulation['args']
            if not filterRegexList or ['x' for regex in filterRegexList if re.search(regex, title)]: # if any regex matches title
                if not excludeFilterRegexList or not ['x' for regex in excludeFilterRegexList if re.search(regex, title)]: # if NO exclude-regex matches title
                    testcases.append(RuntimeTestCase(title, simulation['file'], simulation['wd'], simulation['args'],
                            simulation['simtimelimit'], simulation['expectedRuntime'], StoreMeasuredRuntimeCallback(simulation), StoreExitcodeCallback(simulation), repeat))
        return testcases

    def commentRemover(self, csvData):
        p = re.compile(' *#.*$')
        for line in csvData:
            yield p.sub('',line)

    # parse the CSV into a list of dicts
    def parseSimulationsTable(self, csvFile):
        simulations = []
        f = open(csvFile, 'rb')
        csvReader = csv.reader(self.commentRemover(f), delimiter=',', quotechar='"', skipinitialspace=True)
        for fields in csvReader:
            if len(fields) == 0:
                continue        # empty line
            if len(fields) != 4:
                raise Exception("Line " + str(csvReader.line_num) + " must contain 4 items, but contains " + str(len(fields)) + ": " + '"' + '", "'.join(fields) + '"')
            simulations.append({'file': csvFile, 'line' : csvReader.line_num, 'wd': fields[0], 'args': fields[1], 'simtimelimit': fields[2], 'expectedRuntime': float(fields[3])})
        f.close()
        return simulations

    def writeUpdatedCSVFiles(self):
        for csvFile, simulations in self.fileToSimulationsMap.iteritems():
            updatedContents = self.formatUpdatedSimulationsTable(csvFile, simulations)
            if updatedContents:
                updatedFile = csvFile + ".UPDATED"
                ff = open(updatedFile, 'w')
                ff.write(updatedContents)
                ff.close()
                print "Check " + updatedFile + " for updated runtimes"

    def writeFailedCSVFiles(self):
        for csvFile, simulations in self.fileToSimulationsMap.iteritems():
            failedContents = self.formatFailedSimulationsTable(csvFile, simulations)
            if failedContents:
                failedFile = csvFile + ".FAILED"
                ff = open(failedFile, 'w')
                ff.write(failedContents)
                ff.close()
                print "Check " + failedFile + " for failures"

    def writeErrorCSVFiles(self):
        for csvFile, simulations in self.fileToSimulationsMap.iteritems():
            errorContents = self.formatErrorSimulationsTable(csvFile, simulations)
            if errorContents:
                errorFile = csvFile + ".ERROR"
                ff = open(errorFile, 'w')
                ff.write(errorContents)
                ff.close()
                print "Check " + errorFile + " for errors"

    def escape(self, str):
        if re.search(r'[\r\n\",]', str):
            str = '"' + re.sub('"','""',str) + '"'
        return str

    def formatUpdatedSimulationsTable(self, csvFile, simulations):
        # if there is a measured runtime, print that instead of existing one
        ff = open(csvFile, 'r')
        lines = ff.readlines()
        ff.close()
        lines.insert(0, '')     # csv line count is 1..n; insert an empty item --> lines[1] is the first line

        containsMeasuredRuntime = False
        for simulation in simulations:
            if 'measuredRuntime' in simulation:
                expectedRuntime = simulation['expectedRuntime']
                measuredRuntime = simulation['measuredRuntime']

                containsMeasuredRuntime = True
                line = simulation['line']
                pattern = "\\b" + str(expectedRuntime) + "\\b"
                (newLine, cnt) = re.subn(pattern, str(measuredRuntime), lines[line])
                if (cnt == 1):
                    lines[line] = newLine
                else:
                    print "ERROR: Cannot replace runtime '%s' to '%s' at '%s' line %d:\n     %s" % (expectedRuntime, measuredRuntime, csvFile, line, lines[line])
        return ''.join(lines) if containsMeasuredRuntime else None

    def formatFailedSimulationsTable(self, csvFile, simulations):
        ff = open(csvFile, 'r')
        lines = ff.readlines()
        ff.close()
        lines.insert(0, '')     # csv line count is 1..n; insert an empty item --> lines[1] is the first line
        result = []

        containsFailures = False
        for simulation in simulations:
            if 'measuredRuntime' in simulation:
                expectedRuntime = simulation['expectedRuntime']
                measuredRuntime = simulation['measuredRuntime']
                if expectedRuntime != measuredRuntime:
                    if not containsFailures:
                        containsFailures = True
                        result.append("# Failures:\n");
                    result.append(lines[simulation['line']])
        return ''.join(result) if containsFailures else None

    def formatErrorSimulationsTable(self, csvFile, simulations):
        ff = open(csvFile, 'r')
        lines = ff.readlines()
        ff.close()
        lines.insert(0, '')     # csv line count is 1..n; insert an empty item --> lines[1] is the first line
        result = []

        containsErrors = False
        for simulation in simulations:
            if 'exitcode' in simulation and simulation['exitcode'] != 0:
                if not containsErrors:
                    containsErrors = True
                    result.append("# Errors:\n");
                result.append(lines[simulation['line']])

        return ''.join(result) if containsErrors else None


class SimulationResult:
    def __init__(self, command, workingdir, exitcode, errorMsg=None, isRuntimeOK=None,
            measuredRuntime=None, simulatedTime=None, numEvents=None, elapsedTime=None, cpuTimeLimitReached=None):
        self.command = command
        self.workingdir = workingdir
        self.exitcode = exitcode
        self.errorMsg = errorMsg
        self.isRuntimeOK = isRuntimeOK
        self.measuredRuntime = measuredRuntime
        self.simulatedTime = simulatedTime
        self.numEvents = numEvents
        self.elapsedTime = elapsedTime
        self.cpuTimeLimitReached = cpuTimeLimitReached


class SimulationTestCase(unittest.TestCase):
    def runSimulation(self, title, command, workingdir, resultdir):
        global logFile
        ensure_dir(workingdir + "/results")

        # run the program and log the output
        t0 = time.time()

        (exitcode, out) = self.runProgram(command, workingdir)

        elapsedTime = time.time() - t0
        FILE = open(logFile, "a")
        FILE.write("------------------------------------------------------\n"
                 + "Running: " + title + "\n\n"
                 + "$ cd " + workingdir + "\n"
                 + "$ " + command + "\n\n"
                 + out.strip() + "\n\n"
                 + "Exit code: " + str(exitcode) + "\n"
                 + "Elapsed time:  " + str(round(elapsedTime,2)) + "s\n\n")
        FILE.close()

        FILE = open(resultdir + "/test.out", "w")
        FILE.write("------------------------------------------------------\n"
                 + "Running: " + title + "\n\n"
                 + "$ cd " + workingdir + "\n"
                 + "$ " + command + "\n\n"
                 + out.strip() + "\n\n"
                 + "Exit code: " + str(exitcode) + "\n"
                 + "Elapsed time:  " + str(round(elapsedTime,2)) + "s\n\n")
        FILE.close()

        measuredRuntime = float(re.search("Simulation Total Runtime: ([0-9]+\.[0-9]*)", out).group(1))
        result = SimulationResult(command, workingdir, exitcode, measuredRuntime=measuredRuntime, elapsedTime=elapsedTime)

        # process error messages
        errorLines = re.findall("<!>.*", out, re.M)
        errorMsg = ""
        for err in errorLines:
            err = err.strip()
            errorMsg += "\n" + err
            if re.search("CPU time limit reached", err):
                result.cpuTimeLimitReached = True
            m = re.search("at event #([0-9]+), t=([0-9]*(\\.[0-9]+)?)", err)
            if m:
                result.numEvents = int(m.group(1))
                result.simulatedTime = float(m.group(2))

        result.errormsg = errorMsg.strip()
        return result

    def runProgram(self, command, workingdir):
        process = subprocess.Popen("sudo chrt -f 1 " + command, shell=True, cwd=workingdir, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        out = process.communicate()[0]
        out = re.sub("\r", "", out)
        return (process.returncode, out)


class RuntimeTestCase(SimulationTestCase):
    def __init__(self, title, csvFile, wd, args, simtimelimit, expectedRuntime, storeMeasuredRuntimeCallback, storeExitcodeCallback, repeat):
        SimulationTestCase.__init__(self)
        self.title = title
        self.csvFile = csvFile
        self.wd = wd
        self.args = args
        self.simtimelimit = simtimelimit
        self.expectedRuntime = expectedRuntime
        self.storeMeasuredRuntimeCallback = storeMeasuredRuntimeCallback
        self.storeExitcodeCallback = storeExitcodeCallback
        self.repeat = repeat

    def runTest(self):
        # CPU time limit is a safety guard: runtime checks shouldn't take forever
        global inetRoot, executable, nedPath, cpuTimeLimit, extraOppRunArgs

        # run the simulation
        workingdir = _iif(self.wd.startswith('/'), inetRoot + "/" + self.wd, self.wd)
        wdname = '' + self.wd + ' ' + self.args
        wdname = re.sub('/', '_', wdname);
        wdname = re.sub('[\W]+', '_', wdname);
        resultdir = os.path.abspath(".") + "/results/" + self.csvFile + "/" + wdname;
        if not os.path.exists(resultdir):
            try:
                os.makedirs(resultdir)
            except OSError:
                pass
        command = executable + " -n " + nedPath + " -l " + inetLib + " -u Cmdenv " + self.args + \
            _iif(self.simtimelimit != "", " --sim-time-limit=" + self.simtimelimit, "") + \
            " --cmdenv-express-mode=true --cmdenv-performance-display=false --cmdenv-status-frequency=1000000s --cpu-time-limit=" + cpuTimeLimit + \
            " --record-eventlog=false --vector-recording=false --scalar-recording=false --bin-recording=false --param-record-as-scalar=false " + \
            " --result-dir=" + resultdir + \
            extraOppRunArgs;
        # print "COMMAND: " + command + '\n'
        measuredRuntimes = []
        for rep in range(self.repeat):
            result = self.runSimulation(self.title, command, workingdir, resultdir)

            # process the result
            measuredRuntimes.append(result.measuredRuntime)
            self.storeExitcodeCallback(result.exitcode)
            if result.exitcode != 0:
                raise Exception("runtime error:" + result.errormsg)
            elif result.cpuTimeLimitReached:
                raise Exception("cpu time limit exceeded")
            elif result.simulatedTime == 0 and self.simtimelimit != '0s':
                raise Exception("zero time simulated")

        averageMeasuredRuntime = numpy.mean(self.rejectOutliers(measuredRuntimes))
        self.storeMeasuredRuntimeCallback(averageMeasuredRuntime)
        assert self.checkMeasuredRuntime(averageMeasuredRuntime), "%f%% %f -> %f" % (100 * averageMeasuredRuntime / self.expectedRuntime, self.expectedRuntime, averageMeasuredRuntime)

    def checkMeasuredRuntime(self, measuredRuntime):
        return ((self.expectedRuntime / runtimeTolerance) < measuredRuntime) and ((runtimeTolerance * self.expectedRuntime) > measuredRuntime)

    def rejectOutliers(self, elements):
        if len(elements) < 2:
            return elements
        m = outlierThreshold
        u = numpy.mean(elements)
        s = numpy.std(elements)
        return [e for e in elements if (u - m * s < e < u + m * s)]

    def __str__(self):
        return self.title

class ThreadSafeIter:
    """Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def next(self):
        with self.lock:
            return self.it.next()



class ThreadedTestSuite(unittest.BaseTestSuite):
    """ runs toplevel tests in n threads
    """

    # How many test process at the time.
    thread_count = multiprocessing.cpu_count()

    def run(self, result):
        it = ThreadSafeIter(self.__iter__())

        result.buffered = True

        threads = []

        for i in range(self.thread_count):
            # Create self.thread_count number of threads that together will
            # cooperate removing every ip in the list. Each thread will do the
            # job as fast as it can.
            t = threading.Thread(target=self.runThread, args=(result, it))
            t.daemon = True
            t.start()
            threads.append(t)

        # Wait until all the threads are done. .join() is blocking.
        #for t in threads:
        #    t.join()
        runApp = True
        while runApp and threading.active_count() > 1:
            try:
                time.sleep(0.1)
            except KeyboardInterrupt:
                runApp = False
        return result

    def runThread(self, result, it):
        tresult = result.startThread()
        for test in it:
            if result.shouldStop:
                break
            test(tresult)
        tresult.stopThread()


class ThreadedTestResult(unittest.TestResult):
    """TestResult with threads
    """

    def __init__(self, stream=None, descriptions=None, verbosity=None):
        super(ThreadedTestResult, self).__init__()
        self.parent = None
        self.lock = threading.Lock()

    def startThread(self):
        ret = copy.copy(self)
        ret.parent = self
        return ret

    def stop():
        super(ThreadedTestResult, self).stop()
        if self.parent:
            self.parent.stop()

    def stopThread(self):
        if self.parent == None:
            return 0
        self.parent.testsRun += self.testsRun
        return 1

    def startTest(self, test):
        "Called when the given test is about to be run"
        super(ThreadedTestResult, self).startTest(test)
        self.oldstream = self.stream
        self.stream = StringIO()

    def stopTest(self, test):
        """Called when the given test has been run"""
        super(ThreadedTestResult, self).stopTest(test)
        out = self.stream.getvalue()
        with self.lock:
            self.stream = self.oldstream
            self.stream.write(out)


#
# Copy/paste of TextTestResult, with minor modifications in the output:
# we want to print the error text after ERROR and FAIL, but we don't want
# to print stack traces.
#
class SimulationTextTestResult(ThreadedTestResult):
    """A test result class that can print formatted text results to a stream.

    Used by TextTestRunner.
    """
    separator1 = '=' * 70
    separator2 = '-' * 70

    def __init__(self, stream, descriptions, verbosity):
        super(SimulationTextTestResult, self).__init__()
        self.stream = stream
        self.showAll = verbosity > 1
        self.dots = verbosity == 1
        self.descriptions = descriptions

    def getDescription(self, test):
        doc_first_line = test.shortDescription()
        if self.descriptions and doc_first_line:
            return '\n'.join((str(test), doc_first_line))
        else:
            return str(test)

    def startTest(self, test):
        super(SimulationTextTestResult, self).startTest(test)
        if self.showAll:
            self.stream.write(self.getDescription(test))
            self.stream.write(" ... ")
            self.stream.flush()

    def addSuccess(self, test):
        super(SimulationTextTestResult, self).addSuccess(test)
        measuredRuntime = test.storeMeasuredRuntimeCallback.simulation['measuredRuntime']
        if self.showAll:
            self.stream.write(": PASS (%f%%)\n" % (100 * measuredRuntime / test.expectedRuntime))
        elif self.dots:
            self.stream.write('.')
            self.stream.flush()

    def addError(self, test, err):
        # modified
        super(SimulationTextTestResult, self).addError(test, err)
        self.errors[-1] = (test, err[1])  # super class method inserts stack trace; we don't need that, so overwrite it
        if self.showAll:
            (cause, detail) = self._splitMsg(err[1])
            self.stream.write(": ERROR (%s)\n" % cause)
            if detail:
                self.stream.write(detail)
                self.stream.write("\n")
        elif self.dots:
            self.stream.write('E')
            self.stream.flush()
        global exitCode
        exitCode = 1   #"there were errors or failures"

    def addFailure(self, test, err):
        # modified
        super(SimulationTextTestResult, self).addFailure(test, err)
        self.failures[-1] = (test, err[1])  # super class method inserts stack trace; we don't need that, so overwrite it
        if self.showAll:
            (cause, detail) = self._splitMsg(err[1])
            self.stream.write(": FAIL (%s)\n" % cause)
            if detail:
                self.stream.write(detail)
                self.stream.write("\n")
        elif self.dots:
            self.stream.write('F')
            self.stream.flush()
        global exitCode
        exitCode = 1   #"there were errors or failures"

    def addSkip(self, test, reason):
        super(SimulationTextTestResult, self).addSkip(test, reason)
        if self.showAll:
            self.stream.write(": skipped {0!r}".format(reason))
            self.stream.write("\n")
        elif self.dots:
            self.stream.write("s")
            self.stream.flush()

    def addExpectedFailure(self, test, err):
        super(SimulationTextTestResult, self).addExpectedFailure(test, err)
        if self.showAll:
            self.stream.write(": expected failure\n")
        elif self.dots:
            self.stream.write("x")
            self.stream.flush()

    def addUnexpectedSuccess(self, test):
        super(SimulationTextTestResult, self).addUnexpectedSuccess(test)
        if self.showAll:
            self.stream.write(": unexpected success\n")
        elif self.dots:
            self.stream.write("u")
            self.stream.flush()

    def printErrors(self):
        # modified
        if self.dots or self.showAll:
            self.stream.write("\n")
        self.printErrorList('Errors', self.errors)
        self.printErrorList('Failures', self.failures)

    def printErrorList(self, flavour, errors):
        # modified
        if errors:
            self.stream.write("%s:\n" % flavour)
        for test, err in errors:
            self.stream.write("  %s (%s)\n" % (self.getDescription(test), self._splitMsg(err)[0]))

    def _splitMsg(self, msg):
        cause = str(msg)
        detail = None
        if cause.count(': '):
            (cause, detail) = cause.split(': ',1)
        return (cause, detail)


def _iif(cond,t,f):
    return t if cond else f

def ensure_dir(f):
    if not os.path.exists(f):
        # when using multiple threads, there is a race condition here
        try:
            os.makedirs(f) # in Python3 we will have exist_ok, which makes this whole function unnecessary
        except OSError as e:
            if e.errno != os.errno.EEXIST: # ignore "OSError: [Errno 17] File exists"
                raise  # re-raise the exception if the error is something else

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run the tests specified in the input files.')
    parser.add_argument('testspecfiles', nargs='*', metavar='testspecfile', help='CSV files that contain the tests to run (default: *.csv). Expected CSV file columns: workingdir, args, simtimelimit, runtime')
    parser.add_argument('-m', '--match', action='append', metavar='regex', help='Line filter: a line (more precisely, workingdir+SPACE+args) must match any of the regular expressions in order for that test case to be run')
    parser.add_argument('-x', '--exclude', action='append', metavar='regex', help='Negative line filter: a line (more precisely, workingdir+SPACE+args) must NOT match any of the regular expressions in order for that test case to be run')
    parser.add_argument('-t', '--threads', type=int, default=1, help='number of parallel threads (default: number of CPUs to use)')
    parser.add_argument('-r', '--repeat', type=int, default=10, help='number of repeating each test (default: 10)')
    parser.add_argument('-a', '--oppargs', action='append', metavar='oppargs', nargs='*', help='extra opp_run arguments without leading "--", you can terminate list with " -- " ')
    parser.add_argument('-e', '--executable', help='Determines which binary to execute (e.g. opp_run_dbg, opp_run_release)')
    parser.add_argument('-l', '--runtimeTolerance', type=float, default=1.01, help='Specifies the maximum allowed relative runtime difference (default: 1%%)')
    parser.add_argument('-o', '--outlierThreshold', type=float, default=2.0, help='Specifies the threshold for removing outliers from the runtime measurements (default: 2 * stddev)')
    args = parser.parse_args()

    if os.path.isfile(logFile):
        FILE = open(logFile, "w")
        FILE.close()

    if not args.testspecfiles:
        args.testspecfiles = glob.glob('*.csv')

    if args.oppargs:
        for oppArgList in args.oppargs:
            for oppArg in oppArgList:
                extraOppRunArgs += " --" + oppArg

    if args.executable:
        executable = args.executable

    runtimeTolerance = args.runtimeTolerance
    outlierThreshold = args.outlierThreshold

    generator = runtimeTestCaseGenerator()
    testcases = generator.generateFromCSV(args.testspecfiles, args.match, args.exclude, args.repeat)

    testSuite = ThreadedTestSuite()
    testSuite.addTests(testcases)
    testSuite.thread_count = args.threads
    testSuite.repeat = args.repeat

    testRunner = unittest.TextTestRunner(stream=sys.stdout, verbosity=9, resultclass=SimulationTextTestResult)

    disabledCPUs = range(args.threads, multiprocessing.cpu_count())
    for cpu in disabledCPUs:
        os.system("echo 0 | sudo tee /sys/devices/system/cpu/cpu%d/online > /dev/null" % (cpu))

    testRunner.run(testSuite)

    for cpu in disabledCPUs:
        os.system("echo 1 | sudo tee /sys/devices/system/cpu/cpu%d/online > /dev/null" % (cpu))

    print
    generator.writeUpdatedCSVFiles()
    generator.writeErrorCSVFiles()
    generator.writeFailedCSVFiles()

    print "Log has been saved to %s" % logFile

    exit(exitCode)

